import numpy as np
import torch
import torch.optim as optim
from sklearn.neighbors import kneighbors_graph
from typing import List
import os
import matplotlib.pyplot as plt

def fmt_sigma(s):
    """Return ``str(s)`` with every '.' spelled as 'p' so it is filename-safe (0.5 -> '0p5')."""
    return "p".join(str(s).split("."))

def generate_helix_given_X(n_samples, sigma, X):
    """Sample Y | X from the noisy helix model.

    With U ~ Uniform(0, 2*pi) and independent N(0, sigma^2) noise e1, e2:

        Y1 = 2X + U sin(2U) + e1
        Y2 = 2X + U cos(2U) + e2

    ``X`` may be a scalar or an array broadcastable to ``(n_samples, 1)``.
    Draws from numpy's global RNG (U first, then the noise).

    Returns:
        np.ndarray of shape ``(n_samples, 2)`` when ``X`` broadcasts as above.
    """
    angle = np.random.uniform(0, 2 * np.pi, size=(n_samples, 1))
    noise = np.random.normal(0, sigma, size=(n_samples, 2))

    shift = 2 * X
    first = shift + angle * np.sin(2 * angle) + noise[:, 0:1]
    second = shift + angle * np.cos(2 * angle) + noise[:, 1:2]
    return np.hstack((first, second))

def generate_helix_data(n_samples, sigma):
    """Draw joint (X, Y) samples from the noisy helix model with X ~ N(0, 1).

    With U ~ Uniform(0, 2*pi) and independent N(0, sigma^2) noise e1, e2:
    Y1 = 2X + U sin(2U) + e1 and Y2 = 2X + U cos(2U) + e2.
    Draws from numpy's global RNG in the order X, U, noise.

    Returns:
        (X, Y) with shapes ``(n_samples, 1)`` and ``(n_samples, 2)``.
    """
    x = np.random.randn(n_samples, 1)
    u = np.random.uniform(0, 2 * np.pi, size=(n_samples, 1))
    eps = np.random.normal(0, sigma, size=(n_samples, 2))

    # Column 0 uses sin, column 1 uses cos; each gets its own noise column.
    columns = [
        2 * x + u * trig(2 * u) + eps[:, col:col + 1]
        for col, trig in enumerate((np.sin, np.cos))
    ]
    return x, np.hstack(columns)

def train_gcds(X_train, Y_train, eta_train, n_samples, Generator, Discriminator,
               seed=42, batch_size=200, epochs=2000, lr=1e-3, print_interval=50):
    """Fit a GCDS generator adversarially against a critic.

    Both networks are built from the zero-argument factories ``Generator`` and
    ``Discriminator`` right after seeding torch with ``seed``. Each epoch
    shuffles ``range(n_samples)`` and walks it in chunks of ``batch_size``.
    For every chunk:

    * the critic takes one Adam step that maximises
      ``mean(D(x, G(x, eta))) - mean(exp(D(x, y)))``;
    * the generator then takes one Adam step that minimises
      ``mean(D(x, G(x, eta)))`` against the just-updated critic.

    ``eta_train`` is indexed with the same permutation as the data, so each
    row keeps the same noise draw in every epoch.

    Returns:
        (G, D): the trained generator and critic.
    """
    torch.manual_seed(seed)
    generator = Generator()
    critic = Discriminator()

    opt_gen = optim.Adam(generator.parameters(), lr=lr)
    opt_critic = optim.Adam(critic.parameters(), lr=lr)

    last_epoch = epochs - 1
    for epoch in range(epochs):
        order = torch.randperm(n_samples)

        for start in range(0, n_samples, batch_size):
            idx = order[start:start + batch_size]
            x, y_real, noise = X_train[idx], Y_train[idx], eta_train[idx]

            y_gen = generator(x, noise)

            # Critic ascent: detach the generator output so only D receives gradients.
            opt_critic.zero_grad()
            critic_real = critic(x, y_real)
            critic_gen = critic(x, y_gen.detach())
            loss_D = torch.mean(critic_gen) - torch.mean(torch.exp(critic_real))
            (-loss_D).backward()
            opt_critic.step()

            # Generator descent against the freshly updated critic.
            opt_gen.zero_grad()
            loss_G = torch.mean(critic(x, y_gen))
            loss_G.backward()
            opt_gen.step()

        should_report = epoch % print_interval == 0 or epoch == last_epoch
        if should_report:
            print(f"Epoch {epoch}/{epochs}: Loss_D={loss_D.item():.4f}, Loss_G={loss_G.item():.4f}")

    return generator, critic

def gaussian_kernel(A: torch.Tensor, B: torch.Tensor, sigma: float = 1) -> torch.Tensor:
    """Return the Gaussian kernel matrix ``K[i, j] = exp(-sigma * ||A[i] - B[j]||^2)``.

    Args:
        A: tensor of shape ``(n, d)``.
        B: tensor of shape ``(m, d)``.
        sigma: multiplier on the squared distance (an inverse bandwidth, not a
            standard deviation). It may be a Python number or a tensor. A
            tensor is used as-is, so gradients can flow into a learnable
            bandwidth.

    Returns:
        Tensor of shape ``(n, m)`` with the dtype of the inputs.
    """
    diff = A.unsqueeze(1) - B.unsqueeze(0)
    # Sum of squares directly: avoids a sqrt followed by squaring.
    sq_dist = (diff * diff).sum(dim=2)
    return torch.exp(-sigma * sq_dist)

def ECMMD(Z, Y, X, batch, kernel, neighbors: int, kernel_band=1) -> torch.Tensor:
    """Estimate the expected conditional MMD between samples ``Z`` and ``Y`` given ``X``.

    A k-nearest-neighbour graph on the rows of ``X`` is built with
    scikit-learn's ``kneighbors_graph`` (self excluded). The estimate is the
    mean, over all directed neighbour pairs (i, j) of that graph, of

        H[i, j] = k(Z_i, Z_j) + k(Y_i, Y_j) - k(Z_i, Y_j) - k(Y_i, Z_j)

    Args:
        Z, Y: the two samples to compare, row-aligned with ``X``.
        X: conditioning values. It is used only to build the neighbour graph
            and is detached, so no gradient flows through it.
        batch: kept for backward compatibility and not used. The mean is taken
            over the actual number of neighbour pairs, so a final, shorter
            mini-batch is still normalised correctly.
        kernel: callable ``kernel(A, B, sigma=...)`` returning the matrix of
            pairwise kernel values.
        neighbors: requested neighbours per row. It is clamped to
            ``X.shape[0] - 1`` when the batch is too small to supply that many.
        kernel_band: forwarded to ``kernel`` as its ``sigma`` argument.

    Returns:
        0-dim tensor with the ECMMD estimate. It is differentiable w.r.t.
        ``Z``/``Y`` to the extent that ``kernel`` is.

    Raises:
        ValueError: if ``X`` has fewer than two rows (no neighbours exist).
    """
    n_rows = X.shape[0]
    if n_rows < 2:
        raise ValueError("ECMMD needs at least two rows in X to form a neighbour graph")
    # With include_self=False scikit-learn requires neighbors <= n_rows - 1.
    k = min(int(neighbors), n_rows - 1)

    graph = kneighbors_graph(X.detach().cpu().numpy(), k, include_self=False).tocoo()

    # Kernel matrices through the (possibly differentiable) kernel function.
    kernel_ZZ = kernel(Z, Z, sigma=kernel_band)
    kernel_YY = kernel(Y, Y, sigma=kernel_band)
    kernel_ZY = kernel(Z, Y, sigma=kernel_band)
    kernel_YZ = kernel(Y, Z, sigma=kernel_band)
    H = kernel_ZZ + kernel_YY - kernel_ZY - kernel_YZ

    # Gather only the neighbour entries. Index tensors live on H's device.
    rows = torch.as_tensor(graph.row, dtype=torch.long, device=H.device)
    cols = torch.as_tensor(graph.col, dtype=torch.long, device=H.device)
    return H[rows, cols].sum() / rows.numel()

def train_ecmmd_model(X_train, Y_train, eta_train, n_samples, Generator,
                      seed=42, lr=1e-3, epochs=1000, batch_size=200, neighbors=15,
                      kernel_function=None, kernel_band=1, print_interval=50):
    """Train a conditional generator by minimising the ECMMD loss with Adam.

    Args:
        X_train, Y_train, eta_train: tensors indexed row-wise by a random
            permutation of ``range(n_samples)``. ``eta_train`` holds the
            generator noise and stays paired with the same row every epoch.
        n_samples: number of leading rows used for training.
        Generator: zero-argument factory for the model, which is called as
            ``model(X_batch, eta_batch)``.
        seed: torch seed set before the model is built.
        lr: Adam learning rate.
        epochs: number of passes over the data.
        batch_size: rows per mini-batch. The last one may be smaller.
        neighbors: k for the k-NN graph on ``X_batch`` built inside ``ECMMD``.
        kernel_function: kernel passed to ``ECMMD``. Defaults to
            ``gaussian_kernel`` when None.
        kernel_band: forwarded to the kernel as ``sigma``.
        print_interval: print the last mini-batch loss every this many epochs
            and at the final epoch.

    Returns:
        The trained model.
    """
    if kernel_function is None:
        kernel_function = gaussian_kernel

    torch.manual_seed(seed)
    model = Generator()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        # Random permutation of indices for stochastic mini-batch selection.
        permutation = torch.randperm(n_samples)

        for i in range(0, n_samples, batch_size):
            indices = permutation[i:i + batch_size]
            X_batch = X_train[indices]
            Y_batch = Y_train[indices]
            eta_batch = eta_train[indices]

            # Generate fake outputs using the model.
            Y_fake = model(X_batch, eta_batch)

            # Pass the real batch length: the final mini-batch can be shorter than batch_size.
            loss = ECMMD(Y_batch, Y_fake, X_batch, X_batch.shape[0], kernel_function,
                         neighbors, kernel_band=kernel_band)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if epoch % print_interval == 0 or epoch == epochs - 1:
            print(f"Epoch {epoch}/{epochs}: Loss={loss.item():.4f}")

    return model

def plot_generated_and_test_samples(
    seed,
    G1,
    G2,
    noise_dim,
    sigma,
    generate_helix_given_X,
    x_eval=1,
    sample_size=1000,
    s=1,
    alpha=0.5,
    save_dir_fig="figures",
    save_dir_csv="generated_data",
    panel_basename="helix_comparison",
    save_csv=True,
    return_fig=False,
):
    """Plot GCDS samples, ECMMD samples and reference samples side by side.

    One row of three scatter panels is drawn per noise level. The figure is
    saved as a PDF in ``save_dir_fig`` and shown with ``plt.show()``.

    Args:
        seed: base seed. Row ``r`` reseeds torch and numpy with ``seed + r``.
        G1, G2: GCDS and ECMMD generators. Each may be a ``torch.nn.Module``,
            a dict keyed by sigma, or a callable ``sigma -> model``. Models are
            put in eval mode and called as ``model(X, eta)``.
        noise_dim: width of the standard-normal noise fed to the generators.
        sigma: one noise level or an iterable of them (one row each).
        generate_helix_given_X: reference sampler, called as
            ``generate_helix_given_X(sample_size, sigma, x_eval)``.
        x_eval: conditioning value of X used for every panel.
        sample_size: number of points per panel.
        s, alpha: scatter marker size and transparency.
        save_dir_fig, save_dir_csv: output directories, created if missing.
        panel_basename: file-name prefix of the saved PDF.
        save_csv: if True, write each panel's points to CSV in ``save_dir_csv``.
        return_fig: if True, return the matplotlib figure.

    Returns:
        The figure when ``return_fig`` is True, otherwise None.

    Raises:
        ValueError: if ``sigma`` is an empty iterable.
    """
    os.makedirs(save_dir_fig, exist_ok=True)
    os.makedirs(save_dir_csv, exist_ok=True)

    # Scalars, including 0-d numpy/torch values, mean a single row.
    if np.isscalar(sigma) or getattr(sigma, "ndim", None) == 0:
        sigmas = [float(sigma)]
    else:
        sigmas = [float(x) for x in sigma]
    if not sigmas:
        raise ValueError("sigma must contain at least one noise level")

    nrows = len(sigmas)

    # --------- Figure and label styling parameters ----------
    FIG_WIDTH = 14
    ROW_HEIGHT = 3.2
    ROW_LABEL_FONTSIZE = 22
    ROW_LABEL_OFFSET = 0.40
    ROW_GAP_HSPACE = 0.95
    BOTTOM_MARGIN = 0.12
    # ---------------------------------------------------------------------

    def _get_model(M, sig):
        """Resolve the model for noise level ``sig`` from a dict, factory or model."""
        if isinstance(M, dict):
            return M[sig]
        # Check Module before callable: nn.Module instances are callable too,
        # and M(sig) would run the network's forward pass on a float.
        if isinstance(M, torch.nn.Module):
            return M
        if callable(M):
            return M(sig)
        return M

    def _device_of(model):
        """Device of the model's first parameter, or CPU if it has none."""
        try:
            return next(model.parameters()).device
        except StopIteration:
            return torch.device("cpu")

    def _sample(model):
        """Draw ``sample_size`` outputs at X = x_eval as a numpy array."""
        dev = _device_of(model)
        X_eval_vec = torch.full((sample_size, 1), fill_value=x_eval, dtype=torch.float32, device=dev)
        eta_eval = torch.randn(sample_size, noise_dim, device=dev)
        with torch.no_grad():
            return model(X_eval_vec, eta_eval).detach().cpu().numpy()

    # squeeze=False keeps `axes` 2-D even when there is a single row.
    fig, axes = plt.subplots(
        nrows=nrows, ncols=3, figsize=(FIG_WIDTH, ROW_HEIGHT * nrows), squeeze=False
    )

    for r, sig in enumerate(sigmas):
        torch.manual_seed(seed + r)
        np.random.seed(seed + r)

        g1 = _get_model(G1, sig)
        g2 = _get_model(G2, sig)
        g1.eval(); g2.eval()

        # Draw order (G1 noise, G2 noise, then test data) fixes the RNG streams.
        Y_gen1 = _sample(g1)
        Y_gen2 = _sample(g2)
        Y_test = generate_helix_given_X(sample_size, sig, x_eval)

        if save_csv:
            tag = fmt_sigma(sig)
            np.savetxt(os.path.join(save_dir_csv, f"gcds_generated_samples_sigma{tag}.csv"), Y_gen1, delimiter=",")
            np.savetxt(os.path.join(save_dir_csv, f"ecmmd_generated_samples_sigma{tag}.csv"), Y_gen2, delimiter=",")
            np.savetxt(os.path.join(save_dir_csv, f"test_samples_sigma{tag}.csv"), Y_test, delimiter=",")

        panels = (
            (Y_gen1, "Generated by GCDS"),
            (Y_gen2, "Generated by ECMMD"),
            (Y_test, "Test Samples"),
        )
        for ax, (points, title) in zip(axes[r], panels):
            ax.scatter(points[:, 0], points[:, 1], s=s, alpha=alpha)
            ax.set_title(title, fontsize=20)
            ax.set_xlabel(r"$y_1$", fontsize=22)
            ax.set_ylabel(r"$y_2$", fontsize=22)
            ax.grid(True)

        label_prefix = f"({chr(97 + r)}) "  # (a), (b), (c), ...
        axes[r, 1].annotate(
            rf"{label_prefix}$\sigma = {sig}$",
            xy=(0.5, -ROW_LABEL_OFFSET), xycoords="axes fraction",
            ha="center", va="top",
            fontsize=ROW_LABEL_FONTSIZE
        )

    fig.tight_layout()
    fig.subplots_adjust(hspace=ROW_GAP_HSPACE, bottom=BOTTOM_MARGIN)

    if nrows == 1:
        out_pdf = os.path.join(save_dir_fig, f"{panel_basename}_sigma={sigmas[0]}.pdf")
    else:
        tag_all = "_".join(fmt_sigma(v) for v in sigmas)
        out_pdf = os.path.join(save_dir_fig, f"{panel_basename}_sigmas_{tag_all}.pdf")
    fig.savefig(out_pdf, format="pdf", bbox_inches="tight")

    plt.show()

    if return_fig:
        return fig
    return None



